# Declares the ``collective_comm`` raw operator.
#
# The generated body dispatches on the type of ``input``:
#   * ``_mgb.SymbolVar``  -> ``_mgb._Opr.collective_comm_with_input``
#   * otherwise it must be a ``_mgb.CompGraph``
#                         -> ``_mgb._Opr.collective_comm_without_input``
#                            (the operator creates its output without an input var).
#
# NOTE(review): ``config`` is referenced by the body but is not listed in
# ``inputs``; presumably it is supplied by the wrapper signature that
# ``decl_raw_opr`` generates -- confirm against the generator.
decl_raw_opr(
    'collective_comm',
    inputs = [
        Doc('input', 'Input var.'),
        Doc('key', 'The key to NCCL cliques. Operators with same key belong '
            'to the same NCCL operation.', 'str'),
        Doc('nr_devices', 'Total number of devices involved in the NCCL '
            'operation to which this operator belongs.', 'int'),
        Doc('rank', 'Rank of this operator.', 'int'),
        Doc('root', 'Root rank of the broadcast or reduce operation.', 'int'),
        Doc('server_addr', 'IP address of the RPC server.', 'str'),
        Doc('port', 'Listening port of the RPC server.', 'int'),
        Doc('param', 'The only component of *param* is *mode*, which refers to '
            'a specific NCCL operation type.',
            ':class:`~megbrain.opr_param_defs.CollectiveComm`'),
        Doc('dtype', 'Data type of inputs and outputs. Currently this is '
            'required by BROADCAST and optional for other operations. If '
            'specified, it must be consistent with the *dtype* of inputs (if '
            'any).', ':class:`~megbrain.opr_param_defs.DType`', 'None'),
        Doc('backend', 'Backend for collective communication, nccl or ucx.',
            'str', '\'nccl\''),
        Doc('output_buffer', 'The external device buffer that receives the '
            'output result.',
            ':class:`.SharedND`', 'None'),
        Doc('disable', 'If true, the execution returns directly and the output '
            'holds an arbitrary value. The *disable* flag must be the same for '
            'all operators in one collective communication group.',
            ':class:`.SharedScalar`', '_mgb.SharedScalar(0)')
    ],
    # Each string below is one line of generated Python source; the text
    # inside the quotes (including its leading indentation) is significant.
    body = [
        'if isinstance(input, _mgb.SymbolVar):',
        ('    output = _mgb._Opr.collective_comm_with_input(input, key, '
         'nr_devices, rank, root, server_addr, port, '
         '[param.serialize()], dtype, backend, output_buffer, config, disable)'),
        'else:',
        '    assert isinstance(input, _mgb.CompGraph)',
        ('    output = _mgb._Opr.collective_comm_without_input(input, key, '
         'nr_devices, rank, root, server_addr, port, '
         '[param.serialize()], dtype, backend, output_buffer, config, disable)')
    ],
    desc = ('collective communication between multiple CompNodes on multiple '
            'machines')
)

# vim: ft=python
